{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bringing contextual word representations into your models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Christopher Potts\"\n",
    "__version__ = \"CS224u, Stanford, Spring 2022\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Overview](#Overview)\n",
    "1. [General set-up](#General-set-up)\n",
    "1. [Hugging Face BERT models and tokenizers](#Hugging-Face-BERT-models-and-tokenizers)\n",
    "1. [BERT featurization with Hugging Face](#BERT-featurization-with-Hugging-Face)\n",
    "    1. [Simple feed-forward experiment](#Simple-feed-forward-experiment)\n",
    "    1. [A feed-forward experiment with the sst module](#A-feed-forward-experiment-with-the-sst-module)\n",
    "    1. [An RNN experiment with the sst module](#An-RNN-experiment-with-the-sst-module)\n",
    "1. [BERT fine-tuning with Hugging Face](#BERT-fine-tuning-with-Hugging-Face)\n",
    "    1. [HfBertClassifier](#HfBertClassifier)\n",
    "    1. [HfBertClassifier experiment](#HfBertClassifier-experiment)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "This notebook provides a basic introduction to using pre-trained [BERT](https://github.com/google-research/bert) representations with the Hugging Face library. It is meant as a practical companion to our lecture on contextual word representations. The goal of this notebook is just to help you use these representations in your own work.\n",
    "\n",
    "If you haven't already, I encourage you to review the notebook [vsm_03_contextualreps.ipynb](vsm_03_contextualreps.ipynb) before working with this one. That notebook covers the fundamentals of these models; this one dives into the details more quickly.\n",
    "\n",
    "A number of the experiments in this notebook are resource-intensive. I've included timing information for the expensive steps, to give you a sense for how long things are likely to take. I ran this notebook on a laptop with a single NVIDIA RTX 2080 GPU. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General set-up\n",
    "\n",
    "The following are requirements that you'll already have met if you've been working in this repository. As you can see, we'll use the [Stanford Sentiment Treebank](sst_01_overview.ipynb) for illustrations, and we'll try out a few different deep learning models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from sklearn.metrics import classification_report\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import transformers\n",
    "from transformers import BertModel, BertTokenizer\n",
    "\n",
    "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
    "from torch_rnn_classifier import TorchRNNModel\n",
    "from torch_rnn_classifier import TorchRNNClassifier\n",
    "from torch_rnn_classifier import TorchRNNClassifierModel\n",
    "import sst\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.fix_random_seeds()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "SST_HOME = os.path.join(\"data\", \"sentiment\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `transformers` library does a lot of logging. To avoid ending up with a cluttered notebook, I am changing the logging level. You might want to skip this as you scale up to building production systems, since the logging is very good – it gives you a lot of insight into what the models and code are doing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformers.logging.set_verbosity_error()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hugging Face BERT models and tokenizers\n",
    "\n",
    "We'll illustrate with the BERT-base cased model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights_name = 'bert-base-cased'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are lots of other options for pretrained weights. See [this Hugging Face directory](https://huggingface.co/models)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we specify a tokenizer and a model that match both each other and our choice of pretrained weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_tokenizer = BertTokenizer.from_pretrained(weights_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_model = BertModel.from_pretrained(weights_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For modeling (as opposed to creating static representations), we will mostly process examples in batches – generally very small ones, as these models consume _a lot_ of memory. Here's a small batch of texts to use as the starting point for illustrations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_texts = [\n",
    "    \"Encode sentence 1. [SEP] And sentence 2!\",\n",
    "    \"Bert knows Snuffleupagus\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will often need to pad (and perhaps truncate) token lists so that we can work with fixed-dimensional tensors. The `batch_encode_plus` method has a lot of options for doing this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_ids = bert_tokenizer.batch_encode_plus(\n",
    "    example_texts,\n",
    "    add_special_tokens=True,\n",
    "    return_attention_mask=True,\n",
    "    padding='longest')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_ids.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `'token_type_ids'` field is used for multi-text inputs, as in NLI, where it marks which segment each token belongs to. The `'input_ids'` field gives the vocabulary indices for the tokens in each of the two examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[101, 13832, 13775, 5650, 122, 119, 102, 1262, 5650, 123, 106, 102],\n",
       " [101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102, 0, 0]]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_ids['input_ids']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the final two tokens of the second example are pad tokens.\n",
    "\n",
    "For fine-tuning, we want to avoid attending to padded tokens. The `'attention_mask'` captures the needed mask, which we'll be able to feed directly to the pretrained BERT model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_ids['attention_mask']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can run these indices and masks through the pretrained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_example = torch.tensor(example_ids['input_ids'])\n",
    "X_example_mask = torch.tensor(example_ids['attention_mask'])\n",
    "\n",
    "with torch.no_grad():\n",
    "    reps = bert_model(X_example, attention_mask=X_example_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hugging Face BERT models create a special `pooler_output` representation: the final hidden state above the `[CLS]` token, passed through one additional dense layer with a tanh activation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 768])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reps.pooler_output.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have two examples, each represented by a single vector of dimension 768, which is $d_{model}$ for BERT base using the notation from [the original Transformer paper](https://arxiv.org/abs/1706.03762). This is an easy basis for fine-tuning, as we will see.\n",
    "\n",
    "We can also access the final output for each state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 12, 768])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reps.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we have 2 examples, each padded to the length of the longer one (12), and each of those representations has dimension 768. These representations can be used for sequence modeling, or pooled somehow for simple classifiers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Those are all the essential ingredients for working with these parameters in Hugging Face. Of course, the library has a lot of other functionality, but the above suffices to featurize and to fine-tune."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BERT featurization with Hugging Face\n",
    "\n",
    "To start, we'll use the Hugging Face interfaces just to featurize examples to create inputs to a separate model. In this setting, the BERT parameters are frozen."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_phi(text):\n",
    "    input_ids = bert_tokenizer.encode(text, add_special_tokens=True)\n",
    "    X = torch.tensor([input_ids])\n",
    "    with torch.no_grad():\n",
    "        reps = bert_model(X)\n",
    "        return reps.last_hidden_state.squeeze(0).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple feed-forward experiment\n",
    "\n",
    "For a simple feed-forward experiment, we can take the representation above the `[CLS]` token for each example and use it as the input to a shallow neural network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_classifier_phi(text):\n",
    "    reps = bert_phi(text)\n",
    "    #return reps.mean(axis=0)  # Another good, easy option.\n",
    "    return reps[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we read in the SST train and dev splits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = sst.train_reader(SST_HOME)\n",
    "\n",
    "dev = sst.dev_reader(SST_HOME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the input/output pairs out into separate lists:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_str_train = train.sentence.values\n",
    "y_train = train.label.values\n",
    "\n",
    "X_str_dev = dev.sentence.values\n",
    "y_dev = dev.label.values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the next step, we featurize all of the examples. These steps are likely to be the slowest in these experiments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 32min 44s, sys: 52.8 s, total: 33min 37s\n",
      "Wall time: 8min 24s\n"
     ]
    }
   ],
   "source": [
    "%time X_train = [bert_classifier_phi(text) for text in X_str_train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4min 14s, sys: 7.2 s, total: 4min 22s\n",
      "Wall time: 1min 5s\n"
     ]
    }
   ],
   "source": [
    "%time X_dev = [bert_classifier_phi(text) for text in X_str_dev]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that all the examples are featurized, we can fit a model and evaluate it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TorchShallowNeuralClassifier(\n",
    "    early_stopping=True,\n",
    "    hidden_dim=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 45. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.156181752681732"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 21.3 s, sys: 2.56 s, total: 23.9 s\n",
      "Wall time: 8.85 s\n"
     ]
    }
   ],
   "source": [
    "%time _ = model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = model.predict(X_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.696     0.787     0.739       428\n",
      "     neutral      0.342     0.279     0.308       229\n",
      "    positive      0.756     0.732     0.744       444\n",
      "\n",
      "    accuracy                          0.659      1101\n",
      "   macro avg      0.598     0.600     0.597      1101\n",
      "weighted avg      0.647     0.659     0.651      1101\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_dev, preds, digits=3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A feed-forward experiment with the sst module\n",
    "\n",
    "It is straightforward to conduct experiments like the above using `sst.experiment`, which will enable you to do a wider range of experiments without writing or copy-pasting a lot of code. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_shallow_network(X, y):\n",
    "    mod = TorchShallowNeuralClassifier(\n",
    "        hidden_dim=300,\n",
    "        early_stopping=True)\n",
    "    mod.fit(X, y)\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 39. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.242022633552551"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.701     0.806     0.750       428\n",
      "     neutral      0.435     0.162     0.236       229\n",
      "    positive      0.714     0.842     0.773       444\n",
      "\n",
      "    accuracy                          0.687      1101\n",
      "   macro avg      0.617     0.603     0.586      1101\n",
      "weighted avg      0.651     0.687     0.652      1101\n",
      "\n",
      "CPU times: user 38min 14s, sys: 1min 2s, total: 39min 17s\n",
      "Wall time: 9min 49s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = sst.experiment(\n",
    "    sst.train_reader(SST_HOME),\n",
    "    bert_classifier_phi,\n",
    "    fit_shallow_network,\n",
    "    assess_dataframes=sst.dev_reader(SST_HOME),\n",
    "    vectorize=False)  # Pass in the BERT reps directly!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An RNN experiment with the sst module\n",
    "\n",
    "We can also use BERT representations as the input to an RNN. There is just one key change from how we used these models before:\n",
    "\n",
    "* Previously, we would feed in lists of tokens, and they would be converted to indices into a fixed embedding space. This presumes that all words have the same representation no matter what their context is.\n",
    "\n",
    "* With BERT, we skip the embedding entirely and just feed in lists of BERT vectors, which means that the same word can be represented in different ways.\n",
    "\n",
    "`TorchRNNClassifier` supports this via `use_embedding=False`. In turn, you needn't supply a real vocabulary; an empty list suffices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_rnn(X, y):\n",
    "    mod = TorchRNNClassifier(\n",
    "        vocab=[],\n",
    "        early_stopping=True,\n",
    "        use_embedding=False)  # Pass in the BERT hidden states directly!\n",
    "    mod.fit(X, y)\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 32. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.7171962857246399"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.702     0.776     0.737       428\n",
      "     neutral      0.351     0.236     0.282       229\n",
      "    positive      0.747     0.797     0.771       444\n",
      "\n",
      "    accuracy                          0.672      1101\n",
      "   macro avg      0.600     0.603     0.597      1101\n",
      "weighted avg      0.647     0.672     0.656      1101\n",
      "\n",
      "CPU times: user 38min 45s, sys: 1min 39s, total: 40min 24s\n",
      "Wall time: 10min 6s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = sst.experiment(\n",
    "    sst.train_reader(SST_HOME),\n",
    "    bert_phi,\n",
    "    fit_rnn,\n",
    "    assess_dataframes=sst.dev_reader(SST_HOME),\n",
    "    vectorize=False)  # Pass in the BERT hidden states directly!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BERT fine-tuning with Hugging Face\n",
    "\n",
    "The above experiments are quite successful – BERT gives us a reliable boost compared to other methods we've explored for the SST task. However, we might expect to do even better if we fine-tune the BERT parameters as part of fitting our SST classifier. To do that, we need to incorporate the Hugging Face BERT model into our classifier. This too is quite straightforward."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HfBertClassifier\n",
    "\n",
    "The most important step is to create an `nn.Module` subclass that has, for its parameters, both the BERT model and parameters for our own classifier. Here we define a very simple fine-tuning set-up in which some layers built on top of the output corresponding to `[CLS]` are used as the basis for the SST classifier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HfBertClassifierModel(nn.Module):\n",
    "    def __init__(self, n_classes, weights_name='bert-base-cased'):\n",
    "        super().__init__()\n",
    "        self.n_classes = n_classes\n",
    "        self.weights_name = weights_name\n",
    "        self.bert = BertModel.from_pretrained(self.weights_name)\n",
    "        self.bert.train()\n",
    "        self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim\n",
    "        # The only new parameters -- the classifier:\n",
    "        self.classifier_layer = nn.Linear(\n",
    "            self.hidden_dim, self.n_classes)\n",
    "\n",
    "    def forward(self, indices, mask):\n",
    "        reps = self.bert(\n",
    "            indices, attention_mask=mask)\n",
    "        return self.classifier_layer(reps.pooler_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, `self.bert` does the heavy lifting: it reads in all the pretrained BERT parameters. These parameters are trainable by default, so they will be updated during our training process; I've specified `self.bert.train()` to put the model in training mode, which enables dropout.\n",
    "\n",
    "In `forward`, `self.bert` is used to process inputs, and then `pooler_output` is fed into `self.classifier_layer`. Hugging Face has already added a layer on top of the actual output for `[CLS]`, so we can specify the model as\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "[h_{1}, \\ldots, h_{n}] &= \\text{BERT}([x_{1}, \\ldots, x_{n}]) \\\\\n",
    "h &= \\tanh(h_{1}W_{hh} + b_{h}) \\\\\n",
    "y &= \\textbf{softmax}(hW_{hy} + b_{y})\n",
    "\\end{align}$$\n",
    "\n",
    "for a tokenized input sequence $[x_{1}, \\ldots, x_{n}]$. \n",
    "\n",
    "The Hugging Face documentation somewhat amusingly says, of `pooler_output`,\n",
    "\n",
    "> This output is usually _not_ a good summary of the semantic content of the input, you're often better with averaging or pooling the sequence of hidden-states for the whole input sequence.\n",
    "\n",
    "which is entirely reasonable, but it will require more resources, so we'll do the simpler thing here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the training and prediction interface, we can subclass `TorchShallowNeuralClassifier` so that we don't have to write any of our own data-handling, training, or prediction code. The central changes are using `HfBertClassifierModel` in `build_graph` and processing the data with `batch_encode_plus`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HfBertClassifier(TorchShallowNeuralClassifier):\n",
    "    def __init__(self, weights_name, *args, **kwargs):\n",
    "        self.weights_name = weights_name\n",
    "        self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.params += ['weights_name']\n",
    "\n",
    "    def build_graph(self):\n",
    "        return HfBertClassifierModel(self.n_classes_, self.weights_name)\n",
    "\n",
    "    def build_dataset(self, X, y=None):\n",
    "        data = self.tokenizer.batch_encode_plus(\n",
    "            X,\n",
    "            max_length=None,\n",
    "            add_special_tokens=True,\n",
    "            padding='longest',\n",
    "            return_attention_mask=True)\n",
    "        indices = torch.tensor(data['input_ids'])\n",
    "        mask = torch.tensor(data['attention_mask'])\n",
    "        if y is None:\n",
    "            dataset = torch.utils.data.TensorDataset(indices, mask)\n",
    "        else:\n",
    "            self.classes_ = sorted(set(y))\n",
    "            self.n_classes_ = len(self.classes_)\n",
    "            class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
    "            y = [class2index[label] for label in y]\n",
    "            y = torch.tensor(y)\n",
    "            dataset = torch.utils.data.TensorDataset(indices, mask, y)\n",
    "        return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HfBertClassifier experiment\n",
    "\n",
    "That's it! Let's see how we do on the SST ternary, root-only problem. Because fine-tuning is expensive, we'll conduct a modest hyperparameter search and run the model for just one epoch per hyperparameter setting, as we did when [assessing NLI models](nli_02_models.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_fine_tune_phi(text):\n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_hf_bert_classifier_with_hyperparameter_search(X, y):\n",
    "    basemod = HfBertClassifier(\n",
    "        weights_name='bert-base-cased',\n",
    "        batch_size=8,  # Small batches to avoid memory overload.\n",
    "        max_iter=1,  # We'll search based on 1 iteration for efficiency.\n",
    "        n_iter_no_change=5,   # Early-stopping params are for the\n",
    "        early_stopping=True)  # final evaluation.\n",
    "\n",
    "    param_grid = {\n",
    "        'gradient_accumulation_steps': [1, 4, 8],\n",
    "        'eta': [0.00005, 0.0001, 0.001],\n",
    "        # Note: `HfBertClassifierModel` has no hidden layer of its own,\n",
    "        # so `hidden_dim` does not change the model that gets built.\n",
    "        'hidden_dim': [100, 200, 300]}\n",
    "\n",
    "    bestmod = utils.fit_classifier_with_hyperparameter_search(\n",
    "        X, y, basemod, cv=3, param_grid=param_grid)\n",
    "\n",
    "    return bestmod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1 of 1; error is 184.64238105341792"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best params: {'eta': 5e-05, 'gradient_accumulation_steps': 4, 'hidden_dim': 200}\n",
      "Best score: 0.587\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.686     0.930     0.790       428\n",
      "     neutral      0.514     0.079     0.136       229\n",
      "    positive      0.763     0.836     0.798       444\n",
      "\n",
      "    accuracy                          0.715      1101\n",
      "   macro avg      0.655     0.615     0.575      1101\n",
      "weighted avg      0.682     0.715     0.657      1101\n",
      "\n",
      "CPU times: user 1h 27min 12s, sys: 11min 18s, total: 1h 38min 31s\n",
      "Wall time: 1h 37min 44s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "bert_classifier_xval = sst.experiment(\n",
    "    sst.train_reader(SST_HOME),\n",
    "    bert_fine_tune_phi,\n",
    "    fit_hf_bert_classifier_with_hyperparameter_search,\n",
    "    assess_dataframes=sst.dev_reader(SST_HOME),\n",
    "    vectorize=False)  # Pass in the raw texts; the classifier tokenizes them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now on to the final test-set evaluation, using the best model from above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_bert_classifier = bert_classifier_xval['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the rest of the experiment results to clear out some memory:\n",
    "del bert_classifier_xval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_optimized_hf_bert_classifier(X, y):\n",
    "    optimized_bert_classifier.max_iter = 1000\n",
    "    optimized_bert_classifier.fit(X, y)\n",
    "    return optimized_bert_classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_df = sst.sentiment_reader(\n",
    "    os.path.join(SST_HOME, \"sst3-test-labeled.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 9. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 11.503188711278199"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.816     0.754     0.784       912\n",
      "     neutral      0.332     0.501     0.400       389\n",
      "    positive      0.881     0.756     0.813       909\n",
      "\n",
      "    accuracy                          0.710      2210\n",
      "   macro avg      0.676     0.670     0.666      2210\n",
      "weighted avg      0.758     0.710     0.728      2210\n",
      "\n",
      "CPU times: user 9min 54s, sys: 1min 22s, total: 11min 17s\n",
      "Wall time: 11min 16s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = sst.experiment(\n",
    "    sst.train_reader(SST_HOME),\n",
    "    bert_fine_tune_phi,\n",
    "    fit_optimized_hf_bert_classifier,\n",
    "    assess_dataframes=test_df,\n",
    "    vectorize=False)  # Pass in the raw texts; the classifier tokenizes them."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
